-
-
Notifications
You must be signed in to change notification settings - Fork 55
Fix JAX compatibility issues in Job Search III lecture #687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Updated mccall_model_with_sep_markov.md to fix several JAX-related issues: - Refactored vfi() to return only v_final instead of tuple, making it more consistent with VFI pattern - Removed separate successive_approx() function and integrated iteration logic directly into vfi() - Fixed JAX decorators: changed @jit to @jax.jit and @partial(jit, ...) to @partial(jax.jit, ...) - Rewrote get_reservation_wage() to use jnp.argmax() instead of jnp.where() to avoid JAX concretization errors in JIT compilation - Updated all vfi() call sites to explicitly compute policy with get_greedy(v_star, model) - Removed @jit decorators from T() and get_greedy() functions (not needed) Also improved wording in mccall_model_with_separation.md for clarity. Tested: Converted to Python and ran successfully without errors. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
Key Fix: get_reservation_wage() FunctionThe most critical fix in this PR addresses a JAX concretization error. The original implementation used: accept_indices = jnp.where(σ == 1)[0]
if len(accept_indices) == 0:
return jnp.inf
return w_vals[accept_indices[0]]This fails in JIT-compiled functions because:
The new implementation uses: first_accept_idx = jnp.argmax(σ)
return jnp.where(jnp.any(σ), w_vals[first_accept_idx], jnp.inf)This works because:
|
Refactoring Notes: vfi() FunctionThe Before:
After:
Benefits:
All 7 call sites in the lecture were updated to explicitly call |
Testing DetailsThe changes were thoroughly tested:
All fixes are backwards-compatible in terms of functionality - the lecture produces the same numerical results and plots as before. |
|
📖 Netlify Preview Ready! Preview URL: https://pr-687--sunny-cactus-210e3e.netlify.app (6573f2a) 📚 Changed Lecture Pages: mccall_model_with_sep_markov, mccall_model_with_separation |
Updated the McCall model with separation to use h = u(c) + β * sum_w v_u(w) q(w) as the continuation value, matching the notation from the basic McCall model lecture. This makes the progression between lectures more intuitive for readers: - Basic model: h = c + β * sum_w v*(w) q(w) - Separation model: h = u(c) + β * sum_w v_u(w) q(w) Key changes: - Replaced scalar d with h throughout mathematical derivations - Updated closed-form expression for v_e(w) to use h - Modified iteration algorithm to solve for h instead of d - Simplified Bellman equations using h notation - Updated all code functions (compute_v_e, update_h, solve_model) - Changed plots and comments to reference h This improves consistency across the job search lecture series and makes the mathematical structure clearer by explicitly showing the continuation value includes both current utility and discounted future value. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
Standardized continuation value notationI've updated the Job Search II lecture (McCall model with separation) to use more consistent notation with the basic McCall model. Changes made:
Benefits:
All code has been updated and tested successfully. The mathematical derivations and implementation now align better with the pedagogical flow of the lecture series. |
The JAX library is already included in the base environment and doesn't need to be explicitly installed, which was causing build failures. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
The JAX library is already included in the base environment and doesn't need to be explicitly installed, which was causing build failures. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
Removed 'torchaudio' from the installation command for PyTorch.
|
📖 Netlify Preview Ready! Preview URL: https://pr-687--sunny-cactus-210e3e.netlify.app (7f7d6a5) 📚 Changed Lecture Pages: mccall_model_with_sep_markov, mccall_model_with_separation |
|
@jstac this quick fix looks good to go. All execution is passing and the two new lectures are building plots etc. with no Warnings. |
|
Thanks @mmcky !! I'll check the lectures and merge if they are all good. |
|
Are we still using pytorch btw @mmcky ? ripgrep says no ... |
|
Ah, sorry, I see you just removed it! |
|
Thanks @mmcky Another reason to break up these big lecture series, isn't it? The issues in the build were nothing to do with my PRs. I had a pretty stressful morning trying to get the new lectures live and eventually giving up 😬 |
|
Sorry for your morning @jstac. Infrastructure / install issues with JAX are stressful! I am about to run a full test build on another branch -- but there must have been a nightly release of pytorch that has a pretty significant change in it. |
@mmcky this is from perplexity: Yes, you can use NumPyro without PyTorch. NumPyro is a lightweight probabilistic programming library that uses JAX as its backend for automatic differentiation and JIT compilation. It does not depend on PyTorch at all. In fact, NumPyro is designed to leverage JAX's NumPy-compatible API and supports running on CPU, GPU, or TPU through JAX. |
|
Sorry @jstac used the wrong name, should be |
…gma to gamma and use glue for figures
Updated the McCall model with separation lecture with the following changes:
Key changes:
- Changed utility function parameter from σ (sigma) to γ (gamma)
- Moved γ default value from utility function to Model class (γ: float = 2.0)
- Updated all functions (compute_v_e, update_h, solve_model) to pass γ parameter
- Simplified model unpacking to use tuple unpacking directly (e.g., α, β, γ, c, w, q = model)
- Replaced static PNG figures with myst-nb glue functionality
- Added glue import and glue() calls in exercise solutions
- Converted {figure} directives to {glue:figure} directives for dynamic figure generation
Benefits:
- More consistent parameter naming (gamma is standard for CRRA utility)
- Better code organization with parameter defaults in Model class
- Cleaner unpacking syntax
- Dynamic figure generation eliminates need for static PNG files
- Figures automatically stay in sync with code
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
Additional updates to
|
|
📖 Netlify Preview Ready! Preview URL: https://pr-687--sunny-cactus-210e3e.netlify.app (deb99b1) 📚 Changed Lecture Pages: mccall_model_with_sep_markov, mccall_model_with_separation |
Removed static PNG files that are now dynamically generated using myst-nb glue: - mccall_resw_alpha.png - mccall_resw_beta.png - mccall_resw_c.png These figures are now generated from the exercise solution code and displayed via glue:figure directives, eliminating the need for static files and ensuring figures always match the code. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
|
Me and my new best buddy Claude used glue to get rid of some static assets -- png files. Deleting files and dependencies is deeply satisfying. No more worrying about whether the png files are out of sync. |
|
📖 Netlify Preview Ready! Preview URL: https://pr-687--sunny-cactus-210e3e.netlify.app (4e6859a) 📚 Changed Lecture Pages: mccall_model_with_sep_markov, mccall_model_with_separation |
|
📖 Netlify Preview Ready! Preview URL: https://pr-687--sunny-cactus-210e3e.netlify.app (35aa16d) 📚 Changed Lecture Pages: mccall_model_with_sep_markov, mccall_model_with_separation |
Summary
This PR fixes several JAX-related compatibility issues in the Job Search III: Search with Separation and Markov Wages lecture (
mccall_model_with_sep_markov.md).Changes Made
Main Fixes
vfi()function: Now returns onlyv_finalinstead of a tuple(v_star, σ_star), making the implementation cleaner and more consistent with typical VFI patternssuccessive_approx()function: Integrated the iteration logic directly intovfi()to simplify the codebaseget_reservation_wage()function: Rewrote to usejnp.argmax()instead ofjnp.where()to avoid JAX concretization errors that occur during JIT compilation@jitto@jax.jit@partial(jit, ...)to@partial(jax.jit, ...)@jitdecorators fromT()andget_greedy()functionsvfi()call sites: Added explicitσ_star = get_greedy(v_star, model)calls after eachvfi()invocation (7 locations total)Additional Improvement
mccall_model_with_separation.mdfor better clarityTesting
✅ Converted the markdown file to Python using
jupytext✅ Ran the Python script successfully without errors
✅ Verified all MyST markdown structure is valid
✅ Confirmed jupytext can parse and convert the file
Technical Details
The main issue was that
jnp.where()with dynamic size cannot be used inside JIT-compiled functions, causing aConcretizationTypeError. The fix usesjnp.argmax()to find the first acceptance index, which works correctly with JAX's tracing mechanism.🤖 Generated with Claude Code